{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "# Water Detection with Sentinel-2 Imagery\n",
    "\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/water_detection_s2.ipynb)\n",
    "\n",
    "This notebook demonstrates how to train semantic segmentation models for water detection using Sentinel-2 imagery.\n",
    "\n",
    "## Install packages\n",
    "\n",
    "To use the new functionality, ensure the required packages are installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install geoai-py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import geoai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Download sample data\n",
    "\n",
    "We'll use the [Earth Surface Water Dataset](https://zenodo.org/records/5205674#.Y4iEFezP1hE) from Zenodo. Credits to the author (Xin Luo) of the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/dset-s2.zip\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = geoai.download_file(url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images_dir = f\"{data_dir}/dset-s2/tra_scene\"\n",
    "masks_dir = f\"{data_dir}/dset-s2/tra_truth\"\n",
    "tiles_dir = f\"{data_dir}/dset-s2/tiles\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create training data\n",
    "\n",
    "We'll create the same training tiles as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = geoai.export_geotiff_tiles_batch(\n",
    "    images_folder=images_dir,\n",
    "    masks_folder=masks_dir,\n",
    "    output_folder=tiles_dir,\n",
    "    tile_size=512,\n",
    "    stride=128,\n",
    "    quiet=True,\n",
    ")"
   ]
  },
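  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `tile_size=512` and `stride=128`, consecutive tiles overlap by 384 pixels, so each scene yields many more tiles than a non-overlapping grid would. The sketch below estimates the tile count for one scene. It is not part of `geoai`: `count_tiles` is a hypothetical helper, and it assumes that only full tiles lying entirely inside the image are kept, which may differ from how `export_geotiff_tiles_batch` handles image edges.\n",
    "\n",
    "```python\n",
    "def count_tiles(width, height, tile_size=512, stride=128):\n",
    "    \"\"\"Estimate how many full tiles a sliding window extracts from one image.\"\"\"\n",
    "    if width < tile_size or height < tile_size:\n",
    "        return 0  # assumption: images smaller than one tile produce no tiles\n",
    "    cols = (width - tile_size) // stride + 1  # window positions along x\n",
    "    rows = (height - tile_size) // stride + 1  # window positions along y\n",
    "    return cols * rows\n",
    "\n",
    "\n",
    "# Hypothetical 1024 x 1024 scene: 5 window positions per axis\n",
    "print(count_tiles(1024, 1024))  # 25\n",
    "```\n",
    "\n",
    "A smaller stride produces more training tiles, at the cost of more redundancy between neighboring tiles."
   ]
  },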
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Train semantic segmentation model\n",
    "\n",
    "Now we'll train a semantic segmentation model using the new `train_segmentation_model` function. This function supports various architectures from `segmentation-models-pytorch`:\n",
    "\n",
    "- **Architectures**: `unet`, `unetplusplus` `deeplabv3`, `deeplabv3plus`, `fpn`, `pspnet`, `linknet`, `manet`\n",
    "- **Encoders**: `resnet34`, `resnet50`, `efficientnet-b0`, `mobilenet_v2`, etc.\n",
    "\n",
    "For more details, please refer to the [segmentation-models-pytorch documentation](https://smp.readthedocs.io/en/latest/models.html).\n",
    "\n",
    "Let's train the module using U-Net with ResNet34 encoder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test train_segmentation_model with automatic size detection\n",
    "geoai.train_segmentation_model(\n",
    "    images_dir=f\"{tiles_dir}/images\",\n",
    "    labels_dir=f\"{tiles_dir}/masks\",\n",
    "    output_dir=f\"{tiles_dir}/unet_models\",\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    encoder_weights=\"imagenet\",\n",
    "    num_channels=6,\n",
    "    num_classes=2,  # background and water\n",
    "    batch_size=8,\n",
    "    num_epochs=50,\n",
    "    learning_rate=0.001,\n",
    "    val_split=0.2,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the model\n",
    "\n",
    "Let's examine the training curves and model performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.plot_performance_metrics(\n",
    "    history_path=f\"{tiles_dir}/unet_models/training_history.pth\",\n",
    "    figsize=(15, 5),\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/61f675a7-ee67-4650-81c0-f754fe681f4d)"
   ]
  },
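  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference for reading overlap metrics, here is a minimal sketch of how intersection over union (IoU) and the F1 (Dice) score are computed for a binary water mask. It uses NumPy only. `binary_iou_f1` is a hypothetical helper, not a `geoai` function, and the metrics reported by `plot_performance_metrics` may be defined or averaged differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def binary_iou_f1(pred, truth):\n",
    "    \"\"\"Return (IoU, F1) for the water class of two binary masks.\"\"\"\n",
    "    pred = np.asarray(pred).astype(bool)\n",
    "    truth = np.asarray(truth).astype(bool)\n",
    "    tp = np.logical_and(pred, truth).sum()  # water predicted as water\n",
    "    fp = np.logical_and(pred, ~truth).sum()  # background predicted as water\n",
    "    fn = np.logical_and(~pred, truth).sum()  # water predicted as background\n",
    "    union = tp + fp + fn\n",
    "    if union == 0:\n",
    "        return 1.0, 1.0  # assumption: two empty masks count as a perfect match\n",
    "    iou = tp / union\n",
    "    f1 = 2 * tp / (2 * tp + fp + fn)\n",
    "    return float(iou), float(f1)\n",
    "\n",
    "\n",
    "# Toy 2 x 3 masks, 1 = water: tp=2, fp=1, fn=1\n",
    "pred = [[1, 1, 0], [0, 1, 0]]\n",
    "truth = [[1, 0, 0], [0, 1, 1]]\n",
    "iou, f1 = binary_iou_f1(pred, truth)\n",
    "print(f'IoU={iou:.3f}, F1={f1:.3f}')  # IoU=0.500, F1=0.667\n",
    "```\n",
    "\n",
    "Both scores range from 0 to 1, and F1 is never lower than IoU for the same pair of masks."
   ]
  },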
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images_dir = f\"{data_dir}/dset-s2/val_scene\"\n",
    "masks_dir = f\"{data_dir}/dset-s2/val_truth\"\n",
    "predictions_dir = f\"{data_dir}/dset-s2/predictions\"\n",
    "model_path = f\"{tiles_dir}/unet_models/best_model.pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.semantic_segmentation_batch(\n",
    "    input_dir=images_dir,\n",
    "    output_dir=predictions_dir,\n",
    "    model_path=model_path,\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    num_channels=6,\n",
    "    num_classes=2,\n",
    "    window_size=512,\n",
    "    overlap=256,\n",
    "    batch_size=8,\n",
    "    quiet=True,\n",
    ")"
   ]
  },
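  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `window_size=512` and `overlap=256`, consecutive windows start 256 pixels apart, so most pixels are covered by more than one window. The sketch below lists window start offsets along one image axis. `window_starts` is a hypothetical helper, and shifting the last window back so that it ends at the image edge is an assumption about edge handling, not a description of how `semantic_segmentation_batch` works internally.\n",
    "\n",
    "```python\n",
    "def window_starts(length, window_size=512, overlap=256):\n",
    "    \"\"\"List start offsets of overlapping windows along one image axis.\"\"\"\n",
    "    step = window_size - overlap\n",
    "    if step <= 0:\n",
    "        raise ValueError('overlap must be smaller than window_size')\n",
    "    if length <= window_size:\n",
    "        return [0]  # a single window covers (or exceeds) the axis\n",
    "    starts = list(range(0, length - window_size + 1, step))\n",
    "    if starts[-1] + window_size < length:\n",
    "        # assumption: the last window is shifted back to end at the edge\n",
    "        starts.append(length - window_size)\n",
    "    return starts\n",
    "\n",
    "\n",
    "# Hypothetical 1200-pixel axis\n",
    "print(window_starts(1200))  # [0, 256, 512, 688]\n",
    "```\n",
    "\n",
    "A larger overlap smooths predictions near window borders but increases the number of windows the model has to process."
   ]
  },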
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_image_path = (\n",
    "    f\"{data_dir}/dset-s2/val_scene/S2A_L2A_20190318_N0211_R061_6Bands_S2.tif\"\n",
    ")\n",
    "ground_truth_path = (\n",
    "    f\"{data_dir}/dset-s2/val_truth/S2A_L2A_20190318_N0211_R061_S2_Truth.tif\"\n",
    ")\n",
    "prediction_path = (\n",
    "    f\"{data_dir}/dset-s2/predictions/S2A_L2A_20190318_N0211_R061_6Bands_S2_mask.tif\"\n",
    ")\n",
    "save_path = f\"{data_dir}/dset-s2/S2A_L2A_20190318_N0211_R061_6Bands_S2_comparison.png\"\n",
    "\n",
    "fig = geoai.plot_prediction_comparison(\n",
    "    original_image=test_image_path,\n",
    "    prediction_image=prediction_path,\n",
    "    ground_truth_image=ground_truth_path,\n",
    "    titles=[\"Original\", \"Prediction\", \"Ground Truth\"],\n",
    "    figsize=(15, 5),\n",
    "    save_path=save_path,\n",
    "    show_plot=True,\n",
    "    indexes=[5, 4, 3],\n",
    "    divider=5000,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/53601ed7-2bd6-4e7e-b369-4d7bfc2ce120)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "geo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
